/* * Copyright (c) 2010 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ package com.google.api.services.samples.prediction.cmdline; import com.google.api.client.googleapis.auth.oauth2.GoogleCredential; import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport; import com.google.api.client.http.HttpResponse; import com.google.api.client.http.HttpResponseException; import com.google.api.client.http.HttpTransport; import com.google.api.client.json.JsonFactory; import com.google.api.client.json.jackson2.JacksonFactory; import com.google.api.services.prediction.Prediction; import com.google.api.services.prediction.PredictionScopes; import com.google.api.services.prediction.model.Input; import com.google.api.services.prediction.model.Input.InputInput; import com.google.api.services.prediction.model.Insert; import com.google.api.services.prediction.model.Insert2; import com.google.api.services.prediction.model.Output; import com.google.api.services.storage.StorageScopes; import java.io.IOException; import java.io.File; import java.util.Arrays; import java.util.Collections; /** * @author Yaniv Inbar */ public class PredictionSample { /** * Be sure to specify the name of your application. If the application name is {@code null} or * blank, the application will log a warning. Suggested format is "MyCompany-ProductName/1.0". */ private static final String APPLICATION_NAME = "HelloPrediction"; /** Specify the Cloud Storage location of the training data. */ static final String STORAGE_DATA_LOCATION = "your_bucket/language_id.txt"; static final String MODEL_ID = "languageidentifier"; /** * Specify your Google Developers Console project ID, your service account's email address, and * the name of the P12 file you copied to src/main/resources/. */ static final String PROJECT_ID = "your-project-1234"; static final String SERVICE_ACCT_EMAIL = "account123@your-project-1234.iam.gserviceaccount.com"; static final String SERVICE_ACCT_KEYFILE = "YourProject-123456789abc.p12"; /** Global instance of the HTTP transport. */ private static HttpTransport httpTransport; /** Global instance of the JSON factory. */ private static final JsonFactory JSON_FACTORY = JacksonFactory.getDefaultInstance(); /** Authorizes the installed application to access user's protected data. */ private static GoogleCredential authorize() throws Exception { return new GoogleCredential.Builder() .setTransport(httpTransport) .setJsonFactory(JSON_FACTORY) .setServiceAccountId(SERVICE_ACCT_EMAIL) .setServiceAccountPrivateKeyFromP12File(new File( PredictionSample.class.getResource("/"+SERVICE_ACCT_KEYFILE).getFile())) .setServiceAccountScopes(Arrays.asList(PredictionScopes.PREDICTION, StorageScopes.DEVSTORAGE_READ_ONLY)) .build(); } private static void run() throws Exception { httpTransport = GoogleNetHttpTransport.newTrustedTransport(); // authorization GoogleCredential credential = authorize(); Prediction prediction = new Prediction.Builder( httpTransport, JSON_FACTORY, credential).setApplicationName(APPLICATION_NAME).build(); train(prediction); predict(prediction, "Is this sentence in English?"); predict(prediction, "¿Es esta frase en Español?"); predict(prediction, "Est-ce cette phrase en Français?"); } private static void train(Prediction prediction) throws IOException { Insert trainingData = new Insert(); trainingData.setId(MODEL_ID); trainingData.setStorageDataLocation(STORAGE_DATA_LOCATION); prediction.trainedmodels().insert(PROJECT_ID, trainingData).execute(); System.out.println("Training started."); System.out.print("Waiting for training to complete"); System.out.flush(); int triesCounter = 0; Insert2 trainingModel; while (triesCounter < 100) { // NOTE: if model not found, it will throw an HttpResponseException with a 404 error try { HttpResponse response = prediction.trainedmodels().get(PROJECT_ID, MODEL_ID).executeUnparsed(); if (response.getStatusCode() == 200) { trainingModel = response.parseAs(Insert2.class); String trainingStatus = trainingModel.getTrainingStatus(); if (trainingStatus.equals("DONE")) { System.out.println(); System.out.println("Training completed."); System.out.println(trainingModel.getModelInfo()); return; } } response.ignore(); } catch (HttpResponseException e) { } try { // 5 seconds times the tries counter Thread.sleep(5000 * (triesCounter + 1)); } catch (InterruptedException e) { break; } System.out.print("."); System.out.flush(); triesCounter++; } error("ERROR: training not completed."); } private static void error(String errorMessage) { System.err.println(); System.err.println(errorMessage); System.exit(1); } private static void predict(Prediction prediction, String text) throws IOException { Input input = new Input(); InputInput inputInput = new InputInput(); inputInput.setCsvInstance(Collections.<Object>singletonList(text)); input.setInput(inputInput); Output output = prediction.trainedmodels().predict(PROJECT_ID, MODEL_ID, input).execute(); System.out.println("Text: " + text); System.out.println("Predicted language: " + output.getOutputLabel()); } public static void main(String[] args) { try { run(); // success! return; } catch (IOException e) { System.err.println(e.getMessage()); } catch (Throwable t) { t.printStackTrace(); } System.exit(1); } }